Skip to content

Commit 5f5067e

Browse files
committed
support for image attachments when classifying questions
1 parent 42cc62f commit 5f5067e

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

kitsune/llm/questions/classifiers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def classify_question(question: "Question") -> dict[str, Any]:
3232
payload: dict[str, Any] = {
3333
"subject": question.title,
3434
"question": question.content,
35+
"image_urls": [image.get_absolute_url() for image in question.get_images()],
3536
"product": product,
3637
"topics": get_taxonomy(
3738
product, include_metadata=["description", "examples"], output_format="JSON"

kitsune/llm/questions/prompt.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from typing import Any
2+
13
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
2-
from langchain.prompts import ChatPromptTemplate
4+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
5+
from langchain.schema import HumanMessage
6+
from langchain.schema.runnable import RunnableLambda
37

48
SPAM_INSTRUCTIONS = """
59
# Role and goal
@@ -120,10 +124,10 @@
120124
)
121125

122126

123-
spam_prompt = ChatPromptTemplate(
127+
spam_prompt_with_human_message = ChatPromptTemplate(
124128
(
125129
("system", SPAM_INSTRUCTIONS),
126-
("human", USER_QUESTION),
130+
MessagesPlaceholder("human_message"),
127131
)
128132
).partial(format_instructions=spam_parser.get_format_instructions())
129133

@@ -134,3 +138,23 @@
134138
("human", USER_QUESTION),
135139
)
136140
).partial(format_instructions=topic_parser.get_format_instructions())
141+
142+
143+
def add_human_message(inputs: dict) -> dict:
144+
"""
145+
Adds the human message to the inputs dict and returns it. Ensures
146+
that the human message includes image URL's if they're present.
147+
"""
148+
image_urls = inputs.pop("image_urls", None)
149+
150+
content: list[dict[str, Any]] = [dict(type="text", text=USER_QUESTION.format(**inputs))]
151+
152+
if image_urls:
153+
for image_url in image_urls:
154+
content.append(dict(type="image_url", image_url=dict(url=image_url)))
155+
156+
inputs.update(human_message=[HumanMessage(content=content)])
157+
return inputs
158+
159+
160+
spam_prompt = RunnableLambda(add_human_message) | spam_prompt_with_human_message

0 commit comments

Comments
 (0)